{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Continuous Bag of Words (Deep CBOW) Text Classifier\n",
    "\n",
    "The code below implements a continuous bag of words text classifier.\n",
    "- We tokenize the text, create a vocabulary and encode each piece of text in the dataset\n",
    "- We create embeddings for inputs and sum them together\n",
    "- The resulting vector is fed to hidden neural network, which generates a new vector that is multiplied to a weights matrix\n",
    "- We then add the bias and obtain scores\n",
    "- The scores are applied a softmax to generate probabilities which are used for the final classification\n",
    "\n",
    "The code used in this notebook was inspired by code from the [official repo](https://github.com/neubig/nn4nlp-code) used in the [CMU Neural Networks for NLP class](http://www.phontron.com/class/nn4nlp2021/schedule.html) by [Graham Neubig](http://www.phontron.com/index.php). \n",
    "\n",
    "![img txt](https://github.com/dair-ai/ML-Notebooks/blob/main/img/deep_cbow.png?raw=true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "''' uncomment to download the data\n",
    "%%capture\n",
    "\n",
    "# download the files\n",
    "!wget https://raw.githubusercontent.com/neubig/nn4nlp-code/master/data/classes/dev.txt\n",
    "!wget https://raw.githubusercontent.com/neubig/nn4nlp-code/master/data/classes/test.txt\n",
    "!wget https://raw.githubusercontent.com/neubig/nn4nlp-code/master/data/classes/train.txt\n",
    "\n",
    "# create the data folders\n",
    "!mkdir data data/classes\n",
    "!cp dev.txt data/classes\n",
    "!cp test.txt data/classes\n",
    "!cp train.txt data/classes\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read and Process the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to read in data, process each line and split columns by \" ||| \"\n",
    "def read_data(filename):\n",
    "    data = []\n",
    "    with open(filename, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.lower().strip()\n",
    "            line = line.split(' ||| ')\n",
    "            data.append(line)\n",
    "    return data\n",
    "\n",
    "train_data = read_data('data/classes/train.txt')\n",
    "test_data = read_data('data/classes/test.txt')\n",
    "\n",
    "# creating the word and tag indices\n",
    "word_to_index = {}\n",
    "word_to_index[\"<unk>\"] = len(word_to_index) # add <UNK> to dictionary\n",
    "tag_to_index = {}\n",
    "\n",
    "# create word to index dictionary and tag to index dictionary from data\n",
    "def create_dict(data, check_unk=False):\n",
    "    for line in data:\n",
    "        for word in line[1].split(\" \"):\n",
    "            if check_unk == False:\n",
    "                if word not in word_to_index:\n",
    "                    word_to_index[word] = len(word_to_index)\n",
    "            else:\n",
    "                if word not in word_to_index:\n",
    "                    word_to_index[word] = word_to_index[\"<unk>\"]\n",
    "\n",
    "        if line[0] not in tag_to_index:\n",
    "            tag_to_index[line[0]] = len(tag_to_index)\n",
    "\n",
    "create_dict(train_data)\n",
    "create_dict(test_data, check_unk=True)\n",
    "\n",
    "# create word and tag tensors from data\n",
    "def create_tensor(data):\n",
    "    for line in data:\n",
    "        yield([word_to_index[word] for word in line[1].split(\" \")], tag_to_index[line[0]])\n",
    "\n",
    "train_data = list(create_tensor(train_data))\n",
    "test_data = list(create_tensor(test_data))\n",
    "\n",
    "number_of_words = len(word_to_index)\n",
    "number_of_tags = len(tag_to_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# create a simple neural network with embedding layer, bias, and xavier initialization\n",
    "class DeepCBoW(nn.Module):\n",
    "    def __init__(self, nwords, ntags, hidden_size, num_layers, emb_size):\n",
    "        super(DeepCBoW, self).__init__()\n",
    "\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        # layers\n",
    "        self.embedding = nn.Embedding(nwords, emb_size)\n",
    "        self.linears = nn.ModuleList([nn.Linear(emb_size if i ==0 else hidden_size, hidden_size) \\\n",
    "            for i in range(num_layers)])\n",
    "\n",
    "        # use xavier initialization for weights\n",
    "        nn.init.xavier_uniform_(self.embedding.weight)\n",
    "        for i in range(self.num_layers):\n",
    "            nn.init.xavier_uniform_(self.linears[i].weight)\n",
    "\n",
    "        # output layer\n",
    "        self.output_layer = nn.Linear(hidden_size, ntags)\n",
    "\n",
    "    def forward(self, x):\n",
    "        emb = self.embedding(x) # seq x emb_size\n",
    "        emb_sum = torch.sum(emb, dim=0) # emb_size\n",
    "        h = emb_sum.view(1, -1) # reshape to (1, emb_size)\n",
    "        for i in range(self.num_layers):\n",
    "            h = torch.tanh(self.linears[i](h))\n",
    "        out = self.output_layer(h) # 1 x ntags\n",
    "        return out\n",
    "\n",
    "HIDDEN_SIZE = 64\n",
    "NUM_LAYERS = 2 # hidden layers\n",
    "EMB_SIZE = 64\n",
    "model = DeepCBoW(number_of_words, number_of_tags, HIDDEN_SIZE, NUM_LAYERS, EMB_SIZE).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "type = torch.LongTensor\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model.to(device)\n",
    "    type = torch.cuda.LongTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 | train loss/sent: 1.4293 | train accuracy: 0.3765 | test accuracy: 0.3941\n",
      "epoch: 2 | train loss/sent: 1.0343 | train accuracy: 0.5729 | test accuracy: 0.4127\n",
      "epoch: 3 | train loss/sent: 0.6565 | train accuracy: 0.7583 | test accuracy: 0.3801\n",
      "epoch: 4 | train loss/sent: 0.4013 | train accuracy: 0.8586 | test accuracy: 0.3783\n",
      "epoch: 5 | train loss/sent: 0.2659 | train accuracy: 0.9079 | test accuracy: 0.3959\n",
      "epoch: 6 | train loss/sent: 0.1747 | train accuracy: 0.9419 | test accuracy: 0.3787\n",
      "epoch: 7 | train loss/sent: 0.1257 | train accuracy: 0.9573 | test accuracy: 0.3805\n",
      "epoch: 8 | train loss/sent: 0.0860 | train accuracy: 0.9702 | test accuracy: 0.3719\n",
      "epoch: 9 | train loss/sent: 0.0652 | train accuracy: 0.9768 | test accuracy: 0.3747\n",
      "epoch: 10 | train loss/sent: 0.0434 | train accuracy: 0.9860 | test accuracy: 0.3887\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Bad pipe message: %s [b'I7{\\xddYY9\\x10\\xe5', b\"\\xee\\x8a\\xf0\\xff\\xe6\\x1a\\xd2\\x00\\x00|\\xc0,\\xc00\\x00\\xa3\\x00\\x9f\\xcc\\xa9\\xcc\\xa8\\xcc\\xaa\\xc0\\xaf\\xc0\\xad\\xc0\\xa3\\xc0\\x9f\\xc0]\\xc0a\\xc0W\\xc0S\\xc0+\\xc0/\\x00\\xa2\\x00\\x9e\\xc0\\xae\\xc0\\xac\\xc0\\xa2\\xc0\\x9e\\xc0\\\\\\xc0`\\xc0V\\xc0R\\xc0$\\xc0(\\x00k\\x00j\\xc0#\\xc0'\\x00g\\x00@\\xc0\\n\\xc0\\x14\\x009\\x008\\xc0\\t\\xc0\\x13\\x003\\x00\", b'\\x9d\\xc0\\xa1\\xc0\\x9d\\xc0Q\\x00\\x9c\\xc0\\xa0\\xc0\\x9c\\xc0P\\x00=\\x00<\\x005\\x00/\\x00\\x9a\\x00\\x99\\xc0\\x07\\xc0\\x11\\x00\\x96\\x00\\x05\\x00\\xff\\x01\\x00\\x00j\\x00\\x00\\x00\\x0e\\x00\\x0c\\x00\\x00']\n",
      "Bad pipe message: %s [b'\\xe1\\x05', b'\\xb0\\x87g\\xc6U\\xd5G\\xa2.\\xd2\\xf7\\x05\\x9fL\\x00\\x00\\xa6\\xc0,\\xc0', b'\\xa3\\x00\\x9f\\xcc\\xa9\\xcc\\xa8\\xcc\\xaa\\xc0\\xaf\\xc0\\xad\\xc0\\xa3\\xc0\\x9f\\xc0]\\xc0a\\xc0W\\xc0S\\xc0+\\xc0/\\x00\\xa2\\x00\\x9e\\xc0\\xae\\xc0\\xac\\xc0\\xa2\\xc0\\x9e\\xc0\\\\\\xc0`\\xc0V']\n",
      "Bad pipe message: %s [b\"\\xc0$\\xc0(\\x00k\\x00j\\xc0s\\xc0w\\x00\\xc4\\x00\\xc3\\xc0#\\xc0'\\x00g\\x00@\\xc0r\\xc0v\\x00\\xbe\\x00\\xbd\\xc0\\n\\xc0\\x14\\x009\\x008\\x00\\x88\\x00\\x87\\xc0\\t\\xc0\\x13\\x003\\x002\\x00\\x9a\\x00\\x99\\x00E\\x00D\\xc0\\x07\\xc0\\x11\\xc0\\x08\\xc0\\x12\\x00\\x16\\x00\\x13\\x00\\x9d\\xc0\\xa1\\xc0\\x9d\\xc0Q\\x00\\x9c\\xc0\\xa0\\xc0\\x9c\\xc0P\\x00=\\x00\\xc0\\x00<\\x00\\xba\\x005\\x00\\x84\\x00/\\x00\\x96\\x00A\\x00\\x05\\x00\\n\\x00\\xff\\x01\\x00\\x00j\\x00\\x00\\x00\\x0e\\x00\\x0c\\x00\\x00\\t127.0.0.1\\x00\\x0b\\x00\\x04\\x03\\x00\\x01\\x02\\x00\\n\\x00\\x0c\\x00\\n\\x00\\x1d\\x00\\x17\\x00\\x1e\\x00\\x19\\x00\\x18\\x00#\\x00\\x00\\x00\\x16\\x00\\x00\\x00\\x17\\x00\\x00\\x00\\r\\x000\\x00.\\x04\\x03\\x05\\x03\\x06\\x03\\x08\\x07\\x08\\x08\\x08\\t\\x08\\n\\x08\"]\n",
      "Bad pipe message: %s [b'\\xc6\\t^\\x9c\\x07\\xc5y\\xd0\\xbeR\\x8b\\xc2\\x94`T\\xd3\\xcel\\x00\\x00>\\xc0\\x14\\xc0\\n\\x009\\x008']\n",
      "Bad pipe message: %s [b'\\x04\\x08\\x05\\x08\\x06\\x04\\x01\\x05\\x01\\x06']\n",
      "Bad pipe message: %s [b'', b'\\x03\\x03']\n",
      "Bad pipe message: %s [b'']\n",
      "Bad pipe message: %s [b'\\x14\\xc6J\\xf8[H\\x91\\xb3\\x8dV^z\\x9dVA\\xf6Tt\\x00\\x00\\xa2\\xc0\\x14\\xc0\\n\\x009\\x008\\x007\\x006\\x00\\x88\\x00\\x87\\x00\\x86\\x00\\x85\\xc0\\x19\\x00:\\x00\\x89\\xc0\\x0f\\xc0\\x05\\x005\\x00\\x84\\xc0\\x13\\xc0\\t\\x003\\x002\\x001\\x000\\x00\\x9a\\x00\\x99\\x00\\x98\\x00\\x97\\x00E']\n",
      "Bad pipe message: %s [b'', b'\\x02']\n",
      "Bad pipe message: %s [b'\\x05\\x02\\x06']\n",
      "Bad pipe message: %s [b'\\xd8j\\x00\\x0be\\x95\\x1d\\t\\xd2\\xa5\\x02\\xda\\x07;\\x93\\x94$\\x96\\x00\\x00>\\xc0\\x14\\xc0\\n\\x009\\x008\\x007\\x006\\xc0\\x0f\\xc0\\x05\\x005\\xc0\\x13\\xc0\\t\\x003\\x002\\x001\\x000\\xc0\\x0e\\xc0\\x04\\x00/\\x00\\x9a\\x00\\x99\\x00\\x98\\x00\\x97\\x00\\x96\\x00\\x07\\xc0']\n",
      "Bad pipe message: %s [b'\\x07\\xc0\\x0c\\xc0\\x02\\x00\\x05\\x00\\x04\\x00\\xff\\x02\\x01\\x00\\x15\\x03']\n",
      "Bad pipe message: %s [b'1\\x84+\\xad\\xe8(\\xa4\\xf2qZ\\xbd\\x06\\x03\\x10u\\xfe\\x18w\\x00\\x00\\xa2\\xc0\\x14\\xc0', b'9\\x008\\x007\\x006\\x00\\x88']\n",
      "Bad pipe message: %s [b'\\x0c0~\\xec\\xf3\\xe2M\\xe5\\xb4\\xbd:v\\xae\\xca\\xec\\xdb\\xb8!\\x00\\x00\\x86\\xc00\\xc0,\\xc0(\\xc0$\\xc0\\x14\\xc0\\n\\x00\\xa5\\x00', b\"\\xa1\\x00\\x9f\\x00k\\x00j\\x00i\\x00h\\x009\\x008\\x007\\x006\\xc02\\xc0.\\xc0*\\xc0&\\xc0\\x0f\\xc0\\x05\\x00\\x9d\\x00=\\x005\\xc0/\\xc0+\\xc0'\\xc0#\\xc0\\x13\\xc0\\t\\x00\\xa4\\x00\\xa2\\x00\\xa0\\x00\\x9e\\x00g\\x00@\\x00?\\x00>\\x003\\x002\\x001\\x000\\xc01\\xc0-\\xc0)\\xc0%\\xc0\\x0e\\xc0\\x04\\x00\\x9c\\x00<\\x00/\\x00\\x9a\\x00\\x99\\x00\\x98\\x00\\x97\\x00\\x96\\x00\\x07\\xc0\\x11\\xc0\\x07\\xc0\\x0c\\xc0\\x02\\x00\\x05\\x00\\x04\\x00\\xff\\x02\\x01\\x00\\x00g\\x00\\x00\\x00\\x0e\\x00\\x0c\\x00\\x00\\t127.0.0.1\\x00\\x0b\\x00\\x04\\x03\\x00\\x01\\x02\\x00\\n\\x00\\x1c\\x00\\x1a\\x00\\x17\\x00\\x19\\x00\\x1c\\x00\\x1b\"]\n",
      "Bad pipe message: %s [b\"j<Xe^\\xc5\\x00c\\x0b\\xdc\\xdc;\\xdf\\xd9\\xdbseB\\x00\\x00\\xf4\\xc00\\xc0,\\xc0(\\xc0$\\xc0\\x14\\xc0\\n\\x00\\xa5\\x00\\xa3\\x00\\xa1\\x00\\x9f\\x00k\\x00j\\x00i\\x00h\\x009\\x008\\x007\\x006\\x00\\x88\\x00\\x87\\x00\\x86\\x00\\x85\\xc0\\x19\\x00\\xa7\\x00m\\x00:\\x00\\x89\\xc02\\xc0.\\xc0*\\xc0&\\xc0\\x0f\\xc0\\x05\\x00\\x9d\\x00=\\x005\\x00\\x84\\xc0/\\xc0+\\xc0'\\xc0#\\xc0\\x13\\xc0\\t\\x00\\xa4\\x00\\xa2\\x00\\xa0\\x00\\x9e\\x00g\\x00@\\x00?\\x00>\\x003\\x002\\x001\\x000\\x00\\x9a\\x00\\x99\\x00\\x98\\x00\\x97\\x00E\"]\n"
     ]
    }
   ],
   "source": [
    "# perform training of the Bow model\n",
    "\n",
    "for epoch in range(10):\n",
    "    # perform training\n",
    "    model.train()\n",
    "    random.shuffle(train_data)\n",
    "    total_loss = 0.0\n",
    "    train_correct = 0\n",
    "    for sentence, tag in train_data:\n",
    "        sentence = torch.tensor(sentence).type(type)\n",
    "        tag = torch.tensor([tag]).type(type)\n",
    "        output = model(sentence)\n",
    "        predicted = torch.argmax(output.data.detach()).item()\n",
    "        \n",
    "        loss = criterion(output, tag)\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if predicted == tag: train_correct+=1\n",
    "\n",
    "    # perform testing of the model\n",
    "    model.eval()\n",
    "    test_correct = 0\n",
    "    for sentence, tag in test_data:\n",
    "        sentence = torch.tensor(sentence).type(type)\n",
    "        output = model(sentence)\n",
    "        predicted = torch.argmax(output.data.detach()).item()\n",
    "        if predicted == tag: test_correct += 1\n",
    "    \n",
    "    # print model performance results\n",
    "    log = f'epoch: {epoch+1} | ' \\\n",
    "        f'train loss/sent: {total_loss/len(train_data):.4f} | ' \\\n",
    "        f'train accuracy: {train_correct/len(train_data):.4f} | ' \\\n",
    "        f'test accuracy: {test_correct/len(test_data):.4f}'\n",
    "    \n",
    "    print(log)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12 (main, Apr  5 2022, 06:56:58) \n[GCC 7.5.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
